5ace1f
@@ -18,7 +18,6 @@
 package org.apache.hadoop.hive.ql.optimizer.calcite.rules;
 
 import java.util.ArrayList;
-import java.util.EnumSet;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
@@ -32,8 +31,6 @@
 import org.apache.calcite.rel.core.TableScan;
 import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.rex.RexCall;
-import org.apache.calcite.rex.RexInputRef;
-import org.apache.calcite.rex.RexLiteral;
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.rex.RexUtil;
 import org.apache.calcite.sql.SqlKind;
@@ -51,13 +48,7 @@
 
   protected static final Logger LOG = LoggerFactory.getLogger(HivePreFilteringRule.class);
 
-  private static final Set<SqlKind>        COMPARISON = EnumSet.of(SqlKind.EQUALS,
-                                                          SqlKind.GREATER_THAN_OR_EQUAL,
-                                                          SqlKind.LESS_THAN_OR_EQUAL,
-                                                          SqlKind.GREATER_THAN, SqlKind.LESS_THAN,
-                                                          SqlKind.NOT_EQUALS);
-
-  private final FilterFactory              filterFactory;
+  private final FilterFactory filterFactory;
 
   // Max number of nodes when converting to CNF
   private final int maxCNFNodeCount;
@@ -120,7 +111,7 @@
public void onMatch(RelOptRuleCall call) {
 
       for (RexNode operand : operands) {
         if (operand.getKind() == SqlKind.OR) {
-          extractedCommonOperands = extractCommonOperands(rexBuilder, operand, maxCNFNodeCount);
+          extractedCommonOperands = extractCommonOperands(rexBuilder, filter.getInput(), operand, maxCNFNodeCount);
           for (RexNode extractedExpr : extractedCommonOperands) {
             if (operandsToPushDownDigest.add(extractedExpr.toString())) {
               operandsToPushDown.add(extractedExpr);
@@ -155,7 +146,7 @@
public void onMatch(RelOptRuleCall call) {
       break;
 
     case OR:
-      operandsToPushDown = extractCommonOperands(rexBuilder, topFilterCondition, maxCNFNodeCount);
+      operandsToPushDown = extractCommonOperands(rexBuilder, filter.getInput(), topFilterCondition, maxCNFNodeCount);
       break;
     default:
       return;
@@ -191,8 +182,8 @@
public void onMatch(RelOptRuleCall call) {
 
   }
 
-  private static List<RexNode> extractCommonOperands(RexBuilder rexBuilder, RexNode condition,
-          int maxCNFNodeCount) {
+  private static List<RexNode> extractCommonOperands(RexBuilder rexBuilder, RelNode input,
+      RexNode condition, int maxCNFNodeCount) {
     assert condition.getKind() == SqlKind.OR;
     Multimap<String, RexNode> reductionCondition = LinkedHashMultimap.create();
 
@@ -216,27 +207,12 @@
public void onMatch(RelOptRuleCall call) {
           return new ArrayList<>();
         }
         RexCall conjCall = (RexCall) conjunction;
-        RexNode ref = null;
-        if (COMPARISON.contains(conjCall.getOperator().getKind())) {
-          if (conjCall.operands.get(0) instanceof RexInputRef
-              && conjCall.operands.get(1) instanceof RexLiteral) {
-            ref = conjCall.operands.get(0);
-          } else if (conjCall.operands.get(1) instanceof RexInputRef
-              && conjCall.operands.get(0) instanceof RexLiteral) {
-            ref = conjCall.operands.get(1);
-          } else {
-            // We do not know what it is, we bail out for safety
-            return new ArrayList<>();
-          }
-        } else if (conjCall.getOperator().getKind().equals(SqlKind.IN)) {
-          ref = conjCall.operands.get(0);
-        } else if (conjCall.getOperator().getKind().equals(SqlKind.BETWEEN)) {
-          ref = conjCall.operands.get(1);
-        } else {
+        Set<Integer> refs = HiveCalciteUtil.getInputRefs(conjCall);
+        if (refs.size() != 1) {
           // We do not know what it is, we bail out for safety
           return new ArrayList<>();
         }
-
+        RexNode ref = rexBuilder.makeInputRef(input, refs.iterator().next());
         String stringRef = ref.toString();
         reductionCondition.put(stringRef, conjCall);
         refsInCurrentOperand.add(stringRef);
